feat(gdn): separate input and output pool indices #2905
feldsherov wants to merge 1 commit into flashinfer-ai:main from
Conversation
📝 Walkthrough: Adds an optional `output_state_indices` parameter to the Gated Delta Rule decode path, so the updated state can be written to a different pool slot than the one it was read from.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant API as gated_delta_rule_decode_pretranspose
    participant Runner as run_pretranspose_decode / gated_delta_rule_mtp
    participant Kernel as CUDA Kernel (pretranspose / bf16 MTP)
    participant Pool as State Pool
    User->>API: call with initial_state (+pool), initial_state_indices (read), output_state_indices (write)
    API->>API: validate output_state_indices (pool mode, shape [B], int32/64)
    API->>Runner: forward tensors and indices
    Runner->>Kernel: launch with h0_indices (read) and h0_out_indices (write)
    Kernel->>Pool: read state from Pool[h0_indices[b]]
    Kernel->>Kernel: compute gated-delta updates
    Kernel->>Pool: write updated state to Pool[h0_out_indices[b]]
    Pool-->>User: outputs and mutated pool
```
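Read and write indices are decoupled, so the flow above amounts to gather→compute→scatter with two independent index vectors. A toy pure-Python sketch of those pool semantics (names and the update function are illustrative, not the real kernel):

```python
def pool_decode_step(pool, read_indices, write_indices, update):
    """Toy model of the kernel's pool indexing: batch item b reads its
    state from pool[read_indices[b]] and writes the updated state to
    pool[write_indices[b]], which may be a different slot."""
    # Stage all reads before any write so remapped writes cannot race.
    staged = [update(pool[r]) for r in read_indices]
    for w, new_state in zip(write_indices, staged):
        pool[w] = new_state
    return pool

# Batch of 2: read slots 2 and 5, write slots 0 and 1; slots 2 and 5
# keep their original states.
pool = [10, 11, 12, 13, 14, 15]
pool_decode_step(pool, read_indices=[2, 5], write_indices=[0, 1],
                 update=lambda s: s + 100)
# pool is now [112, 115, 12, 13, 14, 15]
```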
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request introduces the output_state_indices parameter to the Gated Delta Rule decode kernels, enabling the updated state to be written to a different pool slot than the one read from. The changes include updates to the high-level Python API, the underlying CUDA kernels in gdn_decode_bf16_state.py and gdn_decode_pretranspose.py, and the addition of verification tests. Review feedback recommends grouping index reads within the kernels to improve consistency and instruction-level parallelism, as well as simplifying redundant logic in the pretranspose runner.
```python
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```
This logic for handling negative indices is duplicated in gated_delta_rule_decode_kernel_seqlen234_unified and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk. To improve consistency and potentially instruction-level parallelism, consider grouping the reads together before the checks, as done in gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.
Additionally, to reduce code duplication across kernels, you could introduce a cute.jit helper function at the module level to handle this pattern.
Suggested change:

```diff
-pool_batch_idx = gH_slot_indices[batch_idx]
-if pool_batch_idx < 0:
-    pool_batch_idx = cutlass.Int32(0)
-write_pool_batch_idx = gH_out_slot_indices[batch_idx]
-if write_pool_batch_idx < 0:
-    write_pool_batch_idx = cutlass.Int32(0)
+pool_batch_idx = gH_slot_indices[batch_idx]
+write_pool_batch_idx = gH_out_slot_indices[batch_idx]
+if pool_batch_idx < 0:
+    pool_batch_idx = cutlass.Int32(0)
+if write_pool_batch_idx < 0:
+    write_pool_batch_idx = cutlass.Int32(0)
```
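The suggested module-level helper might look roughly like the following, sketched in plain Python (the real version would be a `cute.jit`-decorated function operating on `cutlass.Int32` values; `read_slot_pair` is a hypothetical name):

```python
def read_slot_pair(read_slots, write_slots, batch_idx):
    """Hypothetical helper: issue both index reads up front (better
    instruction-level parallelism), then clamp negative padding markers
    to the slot-0 null buffer."""
    read_idx = read_slots[batch_idx]
    write_idx = write_slots[batch_idx]
    if read_idx < 0:
        read_idx = 0
    if write_idx < 0:
        write_idx = 0
    return read_idx, write_idx

# Padded entries (-1) fall back to slot 0 on both sides.
print(read_slot_pair([-1, 3], [5, -1], 0))  # (0, 5)
print(read_slot_pair([-1, 3], [5, -1], 1))  # (3, 0)
```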
```python
if use_pool_indexing and output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```
The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.
You can simplify this logic for better readability.
Suggested change:

```diff
-if use_pool_indexing and output_state_indices is not None:
+if output_state_indices is not None:
     h0_out_indices = output_state_indices.to(torch.int32)
 else:
     h0_out_indices = h0_indices
```
Actionable comments posted: 3
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 748e9c86-0f49-4322-be15-919a3f8c3f91
📒 Files selected for processing (4)
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
- tests/gdn/test_decode_delta_rule.py
```python
if output_state_indices is not None:
    assert use_pool, (
        "output_state_indices can only be used with initial_state (pool mode)"
    )
    assert output_state_indices.shape == (B,), (
        f"Expected output_state_indices shape [{B}], "
        f"got {output_state_indices.shape}"
    )
    assert output_state_indices.dtype in (torch.int32, torch.int64), (
        f"output_state_indices must be int32 or int64, "
        f"got {output_state_indices.dtype}"
    )
```
Reject in-place remaps that alias another batch item's source slot.
output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.
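A minimal sketch of the requested safe-mapping check, written over plain Python index lists rather than torch tensors (the function name and error messages are illustrative):

```python
def validate_remap(read_indices, write_indices):
    """Reject write mappings that make in-place pool updates CTA-order
    dependent: duplicate write targets, or a write target that some
    other batch item still reads as its source."""
    if len(set(write_indices)) != len(write_indices):
        raise ValueError("output_state_indices contains duplicate targets")
    for b, w in enumerate(write_indices):
        # Overwriting your own read slot is fine; clobbering another
        # batch item's source slot is a read/write race.
        if any(w == r for i, r in enumerate(read_indices) if i != b):
            raise ValueError(
                f"batch {b} writes slot {w}, which another batch item reads"
            )

validate_remap([2, 5], [0, 1])  # ok: disjoint write targets
validate_remap([2, 5], [2, 5])  # ok: each item overwrites its own slot
```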
Validate output_state_indices against the pool before dispatch.
The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.
💡 Suggested guard

```diff
 if output_state_indices is not None:
     assert use_pool, (
         "output_state_indices can only be used with initial_state (pool mode)"
     )
     assert output_state_indices.shape == (B,), (
         f"Expected output_state_indices shape [{B}], "
         f"got {output_state_indices.shape}"
     )
     assert output_state_indices.dtype in (torch.int32, torch.int64), (
         f"output_state_indices must be int32 or int64, "
         f"got {output_state_indices.dtype}"
     )
+    assert output_state_indices.device == initial_state.device, (
+        "output_state_indices must be on the same device as initial_state"
+    )
+    pool_size = int(initial_state.shape[0])
+    in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+    assert in_range.all().item(), (
+        f"output_state_indices must be in [0, {pool_size})"
+    )
```
```python
# Read slots must be unchanged (we wrote to different slots)
torch.testing.assert_close(
    pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
)

# Other slots must be unchanged
used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
used_mask[read_indices] = True
used_mask[write_indices] = True
torch.testing.assert_close(
    pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
)
```
Use exact equality for slots that must stay untouched.
These assertions are checking for no mutation, not numerical closeness. Keeping atol/rtol=1e-3 can hide a small stray write, so the read slots and the untouched remainder should be compared with zero tolerance.
💡 Tighten the unchanged-slot assertions

```diff
 # Read slots must be unchanged (we wrote to different slots)
 torch.testing.assert_close(
-    pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+    pool_under_test[read_indices],
+    pool_orig[read_indices],
+    atol=0.0,
+    rtol=0.0,
 )
 # Other slots must be unchanged
 used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
 used_mask[read_indices] = True
 used_mask[write_indices] = True
 torch.testing.assert_close(
-    pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+    pool_under_test[~used_mask],
+    pool_orig[~used_mask],
+    atol=0.0,
+    rtol=0.0,
 )
```
Approved conditional on CI. Edit: though it seems I don't have write access so will need another reviewer to take a look

/bot run

[FAILED] Pipeline #47476702: 7/20 passed

@saltyminty thank you for the review! I see flashinfer-bot is reporting a failed CI pipeline. I am happy to address any issues, but I don't have access to the CI results.

@kahyunnam thank you for the review! @kahyunnam @saltyminty what should I do to land this PR?
Force-pushed from 93687a1 to dc5af70
In the meantime, I rebased onto the latest main and adjusted gdn_decode_bf16state_mtp_kernel to support the change.
Actionable comments posted: 1
```python
# Default output indices to read indices
if output_state_indices is None:
    output_state_indices = initial_state_indices
elif output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
```
Preserve padding/null-buffer semantics when defaulting output_state_indices.
This regresses the existing BF16 negative-index path: padded reads still come in as initial_state_indices == -1, but the None default now copies that -1 straight onto the write side. The kernel uses h0_out_indices for final writeback, so padded rows now write out of bounds (before h0_source) instead of falling back to slot 0.
🐛 Minimal fix

```diff
 # Default output indices to read indices
 if output_state_indices is None:
-    output_state_indices = initial_state_indices
-elif output_state_indices.dtype != torch.int32:
+    # Preserve the existing slot-0 null-buffer behavior for padded rows.
+    output_state_indices = initial_state_indices.clamp_min(0)
+if output_state_indices.dtype != torch.int32:
     output_state_indices = output_state_indices.to(torch.int32)
```
+ output_state_indices = output_state_indices.to(torch.int32)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices (when output_state_indices is None),
preserve padding/null-buffer semantics by cloning initial_state_indices but
mapping padded markers (-1) back to the fallback write slot (e.g., 0) before
use; specifically, in the block handling output_state_indices, set
output_state_indices = initial_state_indices.clone(), then replace any entries
equal to -1 with 0, and finally ensure dtype is torch.int32. This keeps the
kernel's h0_out_indices behavior correct (padded reads won't write to -1
locations) while keeping the int32 conversion logic.
Force-pushed from dc5af70 to c8d21f0
Force-pushed from c8d21f0 to 58cc26d
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Line 2593: The assignment to h0_out_idx_ calling from_dlpack is misformatted;
reformat that line to satisfy ruff (apply ruff format or adjust
spacing/punctuation) so it matches the project's formatting rules (e.g., proper
spacing around the = and within the function call) in the h0_out_idx_ =
from_dlpack(...) statement; keep the same variable name h0_out_idx_ and function
call from_dlpack with arguments output_state_indices, assumed_align=32,
enable_tvm_ffi=True.
---
Duplicate comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: When defaulting output_state_indices to
initial_state_indices, avoid assigning the same tensor with -1 padding because
the write path computes flat_write_idx and will write to negative offsets;
instead, in the branch where output_state_indices is None, create a clone of
initial_state_indices, replace negative values (e.g. -1) with 0 to preserve the
null-slot semantics for writes, and then ensure the tensor is converted to
torch.int32 (matching the existing dtype-check branch). Update the code that
sets output_state_indices so it uses output_state_indices =
initial_state_indices.clone(); output_state_indices[output_state_indices < 0] =
0; output_state_indices = output_state_indices.to(torch.int32) (or equivalent)
so flat_write_idx cannot be negative when used with HV and i_hv.
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: The assertions that verify unmodified slots use
non-zero tolerances which can hide stray writes; change the checks on
pool_under_test vs pool_orig for read_indices and ~used_mask to use exact
equality (e.g., torch.equal or torch.testing.assert_close with atol=0, rtol=0)
so that pool_under_test[read_indices] and pool_under_test[~used_mask] must match
pool_orig exactly; update the two assert calls referencing pool_under_test,
pool_orig, read_indices, write_indices, and used_mask accordingly.
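The exact-equality variant of those two assertions could look like this (tensor names follow the comment; the shapes and index values here are illustrative only):

```python
import torch

# Illustrative pool setup; in the real test these come from the GDN state pool.
pool_orig = torch.randn(8, 4)
pool_under_test = pool_orig.clone()
read_indices = torch.tensor([1, 3])
used_mask = torch.zeros(8, dtype=torch.bool)
used_mask[torch.tensor([1, 3, 5])] = True  # slots the kernel is allowed to touch

# Exact equality (no tolerance), so any stray write to these slots fails the test.
assert torch.equal(pool_under_test[read_indices], pool_orig[read_indices])
assert torch.equal(pool_under_test[~used_mask], pool_orig[~used_mask])
```

torch.testing.assert_close(..., atol=0, rtol=0) would work equally well and gives a more descriptive failure message than torch.equal.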
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3392082e-e475-483c-bc9d-773a2995d6ff
📒 Files selected for processing (4)
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
- tests/gdn/test_decode_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/gdn_decode.py
dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
Fix formatting to pass pre-commit checks.
The pipeline failure indicates this line needs reformatting per ruff format.
🔧 Apply ruff formatting
- h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+ h0_out_idx_ = from_dlpack(
+     output_state_indices, assumed_align=32, enable_tvm_ffi=True
+ )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
h0_out_idx_ = from_dlpack(
    output_state_indices, assumed_align=32, enable_tvm_ffi=True
)
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` at line 2593, The assignment
to h0_out_idx_ calling from_dlpack is misformatted; reformat that line to
satisfy ruff (apply ruff format or adjust spacing/punctuation) so it matches the
project's formatting rules (e.g., proper spacing around the = and within the
function call) in the h0_out_idx_ = from_dlpack(...) statement; keep the same
variable name h0_out_idx_ and function call from_dlpack with arguments
output_state_indices, assumed_align=32, enable_tvm_ffi=True.
@saltyminty should I do anything here?

@feldsherov this is failing pre-commit tests: https://github.com/flashinfer-ai/flashinfer/actions/runs/24164935099/job/70524339364?pr=2905 Can you please rerun pre-commit and push?
📌 Description
Introduce separate output indices parameter for gated_delta_rule_decode_pretranspose.
This addresses the decode part of the feature request in #2873.
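To illustrate the semantics the new parameter enables, here is a plain-Python model of a pool update with separate read and write slots (not the actual kernel API; names and shapes are illustrative):

```python
import torch

def pool_update(pool, read_idx, write_idx, update_fn):
    # Read state from pool[read_idx[b]], compute the new state, and publish it
    # to pool[write_idx[b]]; the read slot stays untouched when the two differ.
    for b in range(read_idx.numel()):
        state = pool[read_idx[b]]
        pool[write_idx[b]] = update_fn(state)

pool = torch.zeros(4, 2)
pool[1] = torch.tensor([1.0, 2.0])
pool_update(pool, torch.tensor([1]), torch.tensor([3]), lambda s: s + 1)
# pool[1] keeps the original state; pool[3] holds the updated state.
```

Keeping the source slot intact is what makes this useful for speculative/MTP decoding, where a rejected step must be able to re-read the unmodified state.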
🔍 Related Issues
#2873
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

I've checked only tests/gdn/test_decode_delta_rule.py on H200. I need help with running the whole test suite.
Reviewer Notes
Summary by CodeRabbit